import collections
from typing import Optional

import d4rl
import gym
import numpy as np
from tqdm import tqdm

# Immutable mini-batch of transitions, as returned by Dataset.sample().
# Unlike Dataset it carries no dones_float field.
Batch = collections.namedtuple(
    typename='Batch',
    field_names=(
        'observations',
        'actions',
        'rewards',
        'masks',
        'next_observations',
        'is_experts',
    ),
)


def split_into_trajectories(observations, actions, rewards, masks, dones_float,
                            next_observations, is_experts):
    """Group parallel per-transition sequences into a list of trajectories.

    Each trajectory is a list of 7-tuples
    ``(obs, act, rew, mask, done, next_obs, is_expert)`` taken index by index
    from the inputs. A new trajectory is started after every index whose
    ``dones_float`` equals 1.0, except after the final index, so no empty
    trailing trajectory is produced. Empty input yields ``[[]]``.
    """
    num_steps = len(observations)
    trajectories = [[]]

    for step in tqdm(range(num_steps)):
        transition = (observations[step], actions[step], rewards[step],
                      masks[step], dones_float[step],
                      next_observations[step], is_experts[step])
        trajectories[-1].append(transition)

        ends_episode = dones_float[step] == 1.0
        has_more_steps = step + 1 < num_steps
        if ends_episode and has_more_steps:
            trajectories.append([])

    return trajectories


def merge_trajectories(trajs):
    """Flatten a list of trajectories back into seven stacked numpy arrays.

    Every step of every trajectory must unpack into exactly seven values
    ``(obs, act, rew, mask, done, next_obs, is_expert)``.

    Returns:
        Tuple ``(observations, actions, rewards, masks, dones_float,
        next_observations, is_experts)``. Each entry is built with ``np.stack``
        over all steps, in trajectory order. ``np.stack`` raises ValueError
        when ``trajs`` contains no steps at all.
    """
    # Unpacking in the comprehension keeps the strict 7-field check per step.
    steps = [
        (obs, act, rew, mask, done, next_obs, is_expert)
        for traj in trajs
        for (obs, act, rew, mask, done, next_obs, is_expert) in traj
    ]
    return tuple(
        np.stack([step[field] for step in steps]) for field in range(7)
    )

def merge_datasets(dataset1, dataset2):
    """Return a new Dataset with dataset1's transitions followed by dataset2's.

    Each per-transition array is concatenated along axis 0. The inputs are
    not modified, and the new size is the sum of the two sizes.
    """
    field_names = ('observations', 'actions', 'rewards', 'masks',
                   'dones_float', 'next_observations', 'is_experts')
    merged_fields = [
        np.concatenate((getattr(dataset1, name), getattr(dataset2, name)),
                       axis=0)
        for name in field_names
    ]
    total_size = dataset1.size + dataset2.size
    return Dataset(*merged_fields, total_size)


class Dataset(object):
    """In-memory store of transitions kept as parallel arrays.

    Every per-transition field is a separate array whose first axis indexes
    transitions. ``size`` is the number of transitions that ``sample`` draws
    from. Trajectory boundaries are the indices where ``dones_float == 1``
    (see ``sample_trajectories``).
    """

    def __init__(self, observations: np.ndarray, actions: np.ndarray,
                 rewards: np.ndarray, masks: np.ndarray,
                 dones_float: np.ndarray, next_observations: np.ndarray,
                 is_experts: np.ndarray,
                 size: int):
        self.observations = observations
        self.actions = actions
        self.rewards = rewards
        self.masks = masks
        self.dones_float = dones_float
        self.next_observations = next_observations
        self.is_experts = is_experts
        self.size = size

    def merge(self, dataset):
        """Append every array of ``dataset`` to this dataset, in place."""
        self.observations = np.concatenate((self.observations, dataset.observations), axis = 0)
        self.actions = np.concatenate((self.actions, dataset.actions), axis = 0)
        self.rewards = np.concatenate((self.rewards, dataset.rewards), axis = 0)
        self.masks = np.concatenate((self.masks, dataset.masks), axis = 0)
        self.dones_float = np.concatenate((self.dones_float, dataset.dones_float), axis = 0)
        self.next_observations = np.concatenate((self.next_observations, dataset.next_observations), axis = 0)
        self.is_experts = np.concatenate((self.is_experts, dataset.is_experts), axis = 0)
        self.size = self.size + dataset.size

    def add_trajectory(self, traj):
        """Append one trajectory to the end of the dataset, in place.

        Args:
            traj: iterable of 7-tuples
                ``(obs, act, rew, mask, done, next_obs, is_expert)``. This is
                the same layout ``sample_trajectories`` returns.

        Each field is concatenated once for the whole trajectory. The previous
        version concatenated once per step, which re-copied the entire
        dataset for every step.
        """
        # NOTE(review): this clears the done flag of the current last
        # transition. As a result, sample_trajectories treats the previous
        # last trajectory and the appended one as a single trajectory.
        # Behavior kept as-is; confirm this is intended.
        if len(self.dones_float) > 0:  # an empty dataset has no last flag
            self.dones_float[-1] = 0.

        # Unpack per step so every step must have exactly seven fields.
        steps = [(obs, act, rew, mask, done, next_obs, is_expert)
                 for (obs, act, rew, mask, done, next_obs, is_expert) in traj]
        if steps:
            (new_obs, new_act, new_rew, new_mask, new_done, new_next_obs,
             new_is_expert) = (np.asarray(column) for column in zip(*steps))
            self.observations = np.concatenate((self.observations, new_obs))
            self.actions = np.concatenate((self.actions, new_act))
            self.rewards = np.concatenate((self.rewards, new_rew))
            self.masks = np.concatenate((self.masks, new_mask))
            self.dones_float = np.concatenate((self.dones_float, new_done))
            self.next_observations = np.concatenate(
                (self.next_observations, new_next_obs))
            self.is_experts = np.concatenate((self.is_experts, new_is_expert))
            self.size += len(steps)
        assert self.size == len(self.observations)

    def add_trajectories(self, trajs):
        """Append each trajectory in ``trajs`` in order (see ``add_trajectory``)."""
        for traj in trajs:
            self.add_trajectory(traj)

    def sample(self, batch_size: int) -> Batch:
        """Draw ``batch_size`` transitions uniformly at random, with replacement."""
        indx = np.random.randint(self.size, size=batch_size).astype(int)
        return Batch(observations=self.observations[indx],
                     actions=self.actions[indx],
                     rewards=self.rewards[indx],
                     masks=self.masks[indx],
                     next_observations=self.next_observations[indx],
                     is_experts=self.is_experts[indx])

    def sample_trajectories(self, num: Optional[int] = 1, return_dataset = False):
        """Sample whole trajectories.

        A trajectory runs from the previous boundary (or index 0) up to and
        including an index where ``dones_float == 1``.

        Args:
            num: number of trajectories to draw, with replacement, via
                ``np.random.choice``. ``None`` selects the whole dataset
                ``[0, size)`` as one segment.
            return_dataset: if True, return a new Dataset with the selected
                transitions. Otherwise return a list of trajectories, each a
                list of 7-tuples
                ``(obs, act, rew, mask, done, next_obs, is_expert)``.
        """
        end_ids = (np.where(self.dones_float == 1)[0] + 1).tolist()
        start_ids = [0] + end_ids[:-1]
        start_end_ids = np.asarray(list(zip(start_ids, end_ids)))
        if num is not None:
            assert len(start_end_ids) >= num
        # Test `num is None` first: comparing an int with None raises TypeError.
        if num is None or len(start_end_ids) < num:
            selected_ids = np.arange(self.size)
            # One (start, end) pair covering everything, so the trajectory
            # branch below can unpack it like any other segment.
            selected_start_end_ids = [(0, self.size)]
        else:
            selected_start_end_ids_ids = np.random.choice(np.arange(len(start_end_ids)), num).astype(int)
            selected_start_end_ids = start_end_ids[selected_start_end_ids_ids]
            selected_ids = np.concatenate(
                [np.arange(start, end) for start, end in selected_start_end_ids])

        if return_dataset:
            return Dataset(
                self.observations[selected_ids],
                self.actions[selected_ids],
                self.rewards[selected_ids],
                self.masks[selected_ids],
                self.dones_float[selected_ids],
                self.next_observations[selected_ids],
                self.is_experts[selected_ids],
                len(selected_ids)
            )

        trajs = []
        for start_id, end_id in selected_start_end_ids:
            segment = slice(start_id, end_id)
            trajs.append(list(zip(
                self.observations[segment],
                self.actions[segment],
                self.rewards[segment],
                self.masks[segment],
                self.dones_float[segment],
                self.next_observations[segment],
                self.is_experts[segment],
            )))
        return trajs
 

class D4RLDataset(Dataset):
    """Dataset filled from ``d4rl.qlearning_dataset(env)``.

    All fields are cast to float32. ``masks`` is ``1 - terminals``.
    ``is_experts`` is 1.0 for every transition when the env id contains
    'expert', else 0.0.
    """

    def __init__(self,
                 env: gym.Env,
                 clip_to_eps: bool = True,
                 eps: float = 1e-5,
                 **kwargs  # accepted for call-site compatibility; unused
                 ):
        dataset = d4rl.qlearning_dataset(env)

        if clip_to_eps:
            lim = 1 - eps
            dataset['actions'] = np.clip(dataset['actions'], -lim, lim)

        observations = dataset['observations']
        next_observations = dataset['next_observations']

        # Transition i (i < N-1) ends a trajectory when it is terminal, or when
        # observation i+1 does not continue from next_observation i. The
        # discontinuity is measured by the L2 norm of the difference.
        # Vectorized over all i instead of a Python loop with a norm call per
        # index.
        dones_float = np.zeros_like(dataset['rewards'])
        diffs = observations[1:] - next_observations[:-1]
        trailing_axes = tuple(range(1, diffs.ndim))
        gaps = np.sqrt(np.sum(np.square(diffs), axis=trailing_axes))
        dones_float[:-1] = np.logical_or(gaps > 1e-6,
                                         dataset['terminals'][:-1] == 1.0)

        # The final stored transition always closes a trajectory.
        dones_float[-1] = 1

        is_expert = 1. if 'expert' in env.unwrapped.spec.id else 0.

        super().__init__(observations.astype(np.float32),
                         actions=dataset['actions'].astype(np.float32),
                         rewards=dataset['rewards'].astype(np.float32),
                         masks=1.0 - dataset['terminals'].astype(np.float32),
                         dones_float=dones_float.astype(np.float32),
                         next_observations=next_observations.astype(
                             np.float32),
                         # float32 to match every other field here and
                         # ReplayBuffer.is_experts (was float64).
                         is_experts=np.full(len(observations), is_expert,
                                            dtype=np.float32),
                         size=len(observations),
                         )


class ReplayBuffer(Dataset):
    """Fixed-capacity ring buffer of transitions.

    Storage is preallocated for ``capacity`` transitions. ``insert_index`` is
    the next slot to write and wraps around to 0. ``size`` counts the filled
    slots, capped at ``capacity``. Sampling is inherited from
    ``Dataset.sample``.
    """

    def __init__(self, observation_space: gym.spaces.Box, action_dim: int,
                 capacity: int):

        observations = np.empty((capacity, *observation_space.shape),
                                dtype=observation_space.dtype)
        actions = np.empty((capacity, action_dim), dtype=np.float32)
        rewards = np.empty((capacity, ), dtype=np.float32)
        masks = np.empty((capacity, ), dtype=np.float32)
        dones_float = np.empty((capacity, ), dtype=np.float32)
        next_observations = np.empty((capacity, *observation_space.shape),
                                     dtype=observation_space.dtype)
        is_experts = np.zeros((capacity, ), dtype=np.float32)

        super().__init__(observations=observations,
                         actions=actions,
                         rewards=rewards,
                         masks=masks,
                         dones_float=dones_float,
                         next_observations=next_observations,
                         is_experts=is_experts,
                         size=0)

        self.insert_index = 0
        self.capacity = capacity

    def initialize_with_dataset(self, dataset: Dataset,
                                num_samples: Optional[int]):
        """Fill an empty buffer with transitions copied from ``dataset``.

        Args:
            dataset: source transitions. Only its first ``dataset.size``
                entries are considered. This matters when the source is itself
                a ReplayBuffer, whose arrays are ``capacity`` long.
            num_samples: how many transitions to copy. ``None`` copies all of
                them. When fewer than all are requested, a random subset is
                taken, chosen by permutation (without replacement).
        """
        # size (not insert_index) tells whether the buffer is empty:
        # insert_index is also 0 after a full wrap-around.
        assert self.size == 0, 'Can only insert a batch online in an empty replay buffer.'

        dataset_size = dataset.size

        if num_samples is None:
            num_samples = dataset_size
        else:
            num_samples = min(dataset_size, num_samples)
        assert self.capacity >= num_samples, 'Dataset cannot be larger than the replay buffer capacity.'

        if num_samples < dataset_size:
            perm = np.random.permutation(dataset_size)
            indices = perm[:num_samples]
        else:
            indices = np.arange(num_samples)

        self.observations[:num_samples] = dataset.observations[indices]
        self.actions[:num_samples] = dataset.actions[indices]
        self.rewards[:num_samples] = dataset.rewards[indices]
        self.masks[:num_samples] = dataset.masks[indices]
        self.dones_float[:num_samples] = dataset.dones_float[indices]
        self.next_observations[:num_samples] = dataset.next_observations[
            indices]
        self.is_experts[:num_samples] = dataset.is_experts[indices]

        # Wrap when the buffer is now full. Otherwise the next insert() would
        # write to index `capacity` and raise IndexError.
        self.insert_index = 0 if num_samples == self.capacity else num_samples
        self.size = num_samples

    def insert(self, observation: np.ndarray, action: np.ndarray,
               reward: float, mask: float, done_float: float, 
               next_observation: np.ndarray, is_experts: np.ndarray):
        """Write one transition at ``insert_index``, overwriting the oldest when full."""
        self.observations[self.insert_index] = observation
        self.actions[self.insert_index] = action
        self.rewards[self.insert_index] = reward
        self.masks[self.insert_index] = mask
        self.dones_float[self.insert_index] = done_float
        self.next_observations[self.insert_index] = next_observation
        self.is_experts[self.insert_index] = is_experts

        self.insert_index = (self.insert_index + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)
